f7f868
@@ -23,13 +23,17 @@
 import java.util.List;
 import java.util.Set;
 
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.ErrorMsg;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapper;
-import org.apache.hadoop.hive.ql.io.HiveInputFormat;
+import org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat;
+import org.apache.hadoop.hive.ql.io.CombineHiveInputFormat;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
@@ -38,6 +42,7 @@
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.stats.StatsFactory;
 import org.apache.hadoop.hive.ql.stats.StatsPublisher;
+import org.apache.hadoop.hive.shims.ShimLoader;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
@@ -46,6 +51,8 @@
 import org.apache.spark.api.java.JavaSparkContext;
 
 public class SparkPlanGenerator {
+  private static final Log LOG = LogFactory.getLog(SparkPlanGenerator.class);
+
   private JavaSparkContext sc;
   private final JobConf jobConf;
   private Context context;
@@ -86,13 +93,45 @@
public SparkPlan generate(SparkWork sparkWork) throws Exception {
+  /**
+   * Creates the Hadoop RDD for the given map work.
+   *
+   * Resolves the input paths and sets them on jobConf, passes the map work to
+   * Utilities.setMapWork, and uses the input format class returned by getInputFormat(mapWork).
+   */
   private JavaPairRDD<BytesWritable, BytesWritable> generateRDD(MapWork mapWork) throws Exception {
     List<Path> inputPaths = Utilities.getInputPaths(jobConf, mapWork, scratchDir, context, false);
     Utilities.setInputPaths(jobConf, inputPaths);
-    Class ifClass = HiveInputFormat.class;
+    Utilities.setMapWork(jobConf, mapWork, scratchDir, true);
+    Class ifClass = getInputFormat(mapWork);
 
     // The mapper class is expected by the HiveInputFormat.
     jobConf.set("mapred.mapper.class", ExecMapper.class.getName());
     return sc.hadoopRDD(jobConf, ifClass, WritableComparable.class, Writable.class);
   }
 
+  /**
+   * Resolves the input format class. A MapWork override is stored in jobConf first; a blank
+   * setting falls back to the Hadoop shims; bucketized work forces BucketizedHiveInputFormat.
+   * @throws HiveException if the resolved class name cannot be loaded
+   */
+  private Class getInputFormat(MapWork mWork) throws HiveException {
+    if (mWork.getInputformat() != null) {
+      HiveConf.setVar(jobConf, HiveConf.ConfVars.HIVEINPUTFORMAT, mWork.getInputformat());
+    }
+    String className = HiveConf.getVar(jobConf, HiveConf.ConfVars.HIVEINPUTFORMAT);
+    if (StringUtils.isBlank(className)) { // isBlank also covers null
+      className = ShimLoader.getHadoopShims().getInputFormatClassName();
+    }
+    if (mWork.isUseBucketizedHiveInputFormat()) {
+      className = BucketizedHiveInputFormat.class.getName();
+    }
+    try {
+      return Class.forName(className);
+    } catch (ClassNotFoundException e) {
+      String message = "Failed to load specified input format class:" + className;
+      LOG.error(message, e);
+      throw new HiveException(message, e);
+    }
+  }
+
   private SparkTran generate(BaseWork bw) throws IOException, HiveException {
     // initialize stats publisher if necessary
     if (bw.isGatheringStats()) {
